This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes NAG optimizer #15543 #16053

Merged
merged 3 commits into from
Sep 11, 2019

Conversation

anirudhacharya
Member

@anirudhacharya anirudhacharya commented Aug 31, 2019

Description

Fixes #15543

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • fix update rule

For review - @zhanghang1989 @apeforest @eric-haibin-lin

@zhanghang1989
Contributor

image

mom = state
mom[:] *= self.momentum
weight[:] += lr * self.momentum * mom
weight[:] -= lr * (1 + self.momentum) * grad
mom[:] -= grad
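
Editorial sketch: the pseudocode above can be made runnable on plain Python floats (an assumption made for illustration; the real optimizer works on MXNet NDArrays in place). The function names are hypothetical. The second function restates the same step in a two-stage form, first updating the momentum and then moving the weight along the look-ahead direction, and the loop checks that both produce the same trajectory on the toy objective f(w) = 0.5 * w**2.

```python
def nag_step_pseudocode(weight, mom, grad, lr, momentum):
    """One step, following the pseudocode line by line (floats, not NDArrays)."""
    mom = mom * momentum
    weight = weight + lr * momentum * mom
    weight = weight - lr * (1 + momentum) * grad
    mom = mom - grad
    return weight, mom


def nag_step_two_stage(weight, mom, grad, lr, momentum):
    """Algebraically equivalent form: update momentum, then look ahead."""
    new_mom = momentum * mom - grad
    new_weight = weight + lr * (momentum * new_mom - grad)
    return new_weight, new_mom


# Toy objective f(w) = 0.5 * w**2, whose gradient is simply w.
w_a, m_a = 1.0, 0.0
w_b, m_b = 1.0, 0.0
max_gap = 0.0
for _ in range(50):
    w_a, m_a = nag_step_pseudocode(w_a, m_a, grad=w_a, lr=0.1, momentum=0.9)
    w_b, m_b = nag_step_two_stage(w_b, m_b, grad=w_b, lr=0.1, momentum=0.9)
    max_gap = max(max_gap, abs(w_a - w_b), abs(m_a - m_b))

print("largest difference between the two forms:", max_gap)
```

Note that in this convention the momentum state does not include the learning rate; the learning rate is applied only when the weight is moved.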

@wkcn
Member

wkcn commented Sep 4, 2019

Hi @zhanghang1989, is there any difference between a -= b and a[:] -= b?
a[:] -= b may incur an extra __getitem__ call.

import mxnet as mx
import time

T = 1000
N = 1000


while 1:
    ti = time.time()
    a = mx.nd.arange(N)
    for i in range(T):
        a += 1
    mx.nd.waitall()
    print('a += b: ', time.time() - ti)

    ti = time.time()
    a = mx.nd.arange(N)
    for i in range(T):
        a[:] += 1
    mx.nd.waitall()
    print('a[:] += b: ', time.time() - ti)

Output:

a += b:  0.06155872344970703
a[:] += b:  0.3492248058319092
a += b:  0.06215381622314453
a[:] += b:  0.30852508544921875
a += b:  0.07872796058654785
a[:] += b:  0.31493425369262695
a += b:  0.08103752136230469
a[:] += b:  0.3226127624511719
a += b:  0.05706977844238281
a[:] += b:  0.29704785346984863
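
Editorial sketch: the extra work comes from how Python evaluates an augmented assignment on a subscript. The toy class below is hypothetical and is not MXNet's NDArray; it only records which special methods Python invokes for each statement, and it does not measure what those calls cost inside MXNet.

```python
class Traced:
    """Toy container that logs the special methods Python calls on it."""

    def __init__(self, value, log):
        self.value = value
        self.log = log

    def __getitem__(self, key):
        self.log.append("__getitem__")
        return Traced(self.value, self.log)  # stands in for a view of the data

    def __setitem__(self, key, other):
        self.log.append("__setitem__")
        self.value = other.value

    def __iadd__(self, other):
        self.log.append("__iadd__")
        self.value += other
        return self


log = []
a = Traced(0, log)

a += 1                      # in-place add on the object itself
plain_calls = list(log)

log.clear()
a[:] += 1                   # read the slice, add in place, write the slice back
sliced_calls = list(log)

print("a += 1    ->", plain_calls)
print("a[:] += 1 ->", sliced_calls)
```

The sliced form goes through three special-method calls instead of one, which is consistent with the slower timings above.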

@anirudhacharya
Member Author

@zhanghang1989 The update rule in this PR is the following -

mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]
                              +(param_momentum+1)*(mom_data[i]
                                -(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));

This update rule is the same as the following pseudocode:

weight = (weight - momentum * mom) + (momentum+1)*(momentum * mom - lr*(grad + wd*weight))

which, when simplified, becomes

weight[:] += (momentum**2 * mom) - (momentum + 1) * lr * (grad + wd*weight)
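
Editorial sketch: the simplification can be checked numerically. The two functions below are hypothetical helpers on plain floats; the first mirrors the kernel statements element by element (including rescale_grad, which the pseudocode folds into grad), and the second is the simplified expression. The loop compares them on random inputs.

```python
import random


def kernel_form(weight, mom, grad, lr, momentum, wd, rescale_grad=1.0):
    """Mirrors the kernel: scale the momentum first, then assign the weight."""
    mom = momentum * mom
    step = lr * (rescale_grad * grad + wd * weight)
    return weight - mom + (momentum + 1) * (mom - step)


def simplified_form(weight, mom, grad, lr, momentum, wd, rescale_grad=1.0):
    """weight += momentum**2 * mom - (momentum + 1) * lr * (grad + wd * weight)"""
    g = rescale_grad * grad + wd * weight
    return weight + momentum ** 2 * mom - (momentum + 1) * lr * g


random.seed(0)
max_abs_diff = 0.0
for _ in range(1000):
    w, m, g = (random.uniform(-1.0, 1.0) for _ in range(3))
    lhs = kernel_form(w, m, g, lr=0.1, momentum=0.9, wd=1e-4)
    rhs = simplified_form(w, m, g, lr=0.1, momentum=0.9, wd=1e-4)
    max_abs_diff = max(max_abs_diff, abs(lhs - rhs))

forms_agree = max_abs_diff < 1e-12
print("forms agree:", forms_agree)
```

The two expressions differ only by floating-point rounding, since the -momentum * mom term cancels against part of (momentum + 1) * momentum * mom.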

Formula -
image

(The same rule is used in Keras as well: https://stats.stackexchange.com/questions/179915/whats-the-difference-between-momentum-based-gradient-descent-and-nesterovs-acc)

Contributor

@zhanghang1989 zhanghang1989 left a comment

The weight update is correct. Please fix the momentum update at the end.

@anirudhacharya
Member Author

The weight update is correct. Please fix the momentum update at the end.

Yes, I will change the momentum state update.

@zhanghang1989
Contributor

Hi @zhanghang1989 , is there any difference between a -= b and a[:] -= b?
a[:] -= b may call extra function __getitem__.

I am not familiar with the symbol API. I just wrote some pseudocode to show how NAG works :)

@eric-haibin-lin eric-haibin-lin changed the title from "Fixes #15543" to "Fixes NAG optimizer #15543" on Sep 5, 2019
@eric-haibin-lin
Member

Thanks @zhanghang1989 and @anirudhacharya

@Vikas-kum
Contributor

larroy pushed a commit to larroy/mxnet that referenced this pull request Sep 28, 2019
* fix update rules

* readable updates in unit test

* mom update
Development

Successfully merging this pull request may close these issues.

Bug in NAG Optimizer
5 participants